# -*- coding:utf-8 -*- 
# Syscall module: building ropchains performing syscalls
from ropgenerator.IO import string_special, banner, string_bold, error
from enum import Enum
import ropgenerator.exploit.syscalls.Linux32 as Linux32
import ropgenerator.exploit.syscalls.Linux64 as Linux64
from ropgenerator.exploit.syscalls.SyscallDef import build_syscall_Linux
from ropgenerator.exploit.Utils import parseFunction, parse_bad_bytes, parse_keep_regs
import ropgenerator.Architecture as Arch
from ropgenerator.Constraints import Constraint, Assertion, BadBytes, RegsNotModified
from ropgenerator.Load import loadedBinary
from ropgenerator.semantic.Engine import getBaseAssertion
from ropgenerator.CommonUtils import set_offset, reset_offset

#####################
# Available systems #
#####################
# Identifiers of the target systems for which syscalls can be built
sysLinux32, sysLinux64 = "LINUX32", "LINUX64"

# Every supported system, in the order they are listed in the help text
availableSystems = [sysLinux32, sysLinux64]

###################
# SYSCALL COMMAND #
###################

# Option selecting the output format of the generated ropchain
OPTION_OUTPUT = '--output-format'
OPTION_OUTPUT_SHORT = '-f'
# Possible values for the output format.
# NOTE: syscall() below only accepts OUTPUT_CONSOLE and OUTPUT_PYTHON;
# OUTPUT_RAW is declared but not handled in this file.
OUTPUT_CONSOLE = 'console'
OUTPUT_PYTHON = 'python'
OUTPUT_RAW = 'raw'
OUTPUT = None # The one chosen (set by syscall() on each invocation)

# Option giving the bytes that must not appear in the payload
OPTION_BAD_BYTES = '--bad-bytes'
OPTION_BAD_BYTES_SHORT = '-b'

# Option giving the registers that the ropchain must not modify
OPTION_KEEP_REGS = '--keep-regs'
OPTION_KEEP_REGS_SHORT = '-k'

# Option listing the supported functions (optionally for given systems)
OPTION_LIST = "--list"
OPTION_LIST_SHORT = "-l"

# Option introducing the system function to call
OPTION_FUNCTION = "--call"
OPTION_FUNCTION_SHORT = "-c" 

# Option giving an offset to add to gadget addresses
OPTION_OFFSET="--offset"
OPTION_OFFSET_SHORT = "-off"

# Option displaying the help text
OPTION_HELP = "--help"
OPTION_HELP_SHORT = "-h"

# Help text of the 'syscall' command, built piece by piece:
# banner, usage, one entry per option, supported systems, function
# format and examples. It is printed as-is by print_help().
CMD_SYSCALL_HELP =  banner([string_bold("'syscall' command"),\
                    string_special("(Call system functions with ROPChains)")])
CMD_SYSCALL_HELP += "\n\n\t"+string_bold("Usage:")+\
"\n\t\tsyscall [OPTIONS]"
CMD_SYSCALL_HELP += "\n\n\t"+string_bold("Options")+":"
# --call
CMD_SYSCALL_HELP += "\n\t\t"+string_special(OPTION_FUNCTION_SHORT)+","+\
    string_special(OPTION_FUNCTION)+" <function>\t Call a system function"
# --bad-bytes
CMD_SYSCALL_HELP += "\n\n\t\t"+string_special(OPTION_BAD_BYTES_SHORT)+","+string_special(OPTION_BAD_BYTES)+" <bytes>\t Bad bytes for payload.\n\t\t\t\t\t Expected format is a list of bytes \n\t\t\t\t\t separated by comas (e.g '-b 0A,0B,2F')"
# --keep-regs
CMD_SYSCALL_HELP += "\n\n\t\t"+string_special(OPTION_KEEP_REGS_SHORT)+","+string_special(OPTION_KEEP_REGS)+" <regs>\t Registers that shouldn't be modified.\n\t\t\t\t\t Expected format is a list of registers \n\t\t\t\t\t separated by comas (e.g '-k edi,eax')"
# --offset
CMD_SYSCALL_HELP += "\n\n\t\t"+string_special(OPTION_OFFSET_SHORT)+","+\
    string_special(OPTION_OFFSET)+" <int>\t Offset to add to gadget addresses"   
    
# --list
CMD_SYSCALL_HELP += "\n\n\t\t"+string_special(OPTION_LIST_SHORT)+","+\
    string_special(OPTION_LIST)+" [<system>]\t List supported functions"
# --output-format (only the console and python formats are advertised)
CMD_SYSCALL_HELP += "\n\n\t\t"+string_special(OPTION_OUTPUT_SHORT)+","+\
    string_special(OPTION_OUTPUT)+\
    " <fmt> Output format for ropchains.\n\t\t\t\t\t Expected format is one of the\n\t\t\t\t\t following: "+\
    string_special(OUTPUT_CONSOLE)+','+string_special(OUTPUT_PYTHON)
    
# --help
CMD_SYSCALL_HELP += "\n\n\t\t"+string_special(OPTION_HELP_SHORT)+","+string_special(OPTION_HELP)+"\t\t Show this help"
# Supported systems, function format and usage examples
CMD_SYSCALL_HELP += "\n\n\t"+string_bold("Supported systems")+": "+', '.join([string_special(s) for s in availableSystems])
CMD_SYSCALL_HELP += "\n\n\t"+string_bold("Function format")+": "+\
    string_special("function")+"( "+string_special("arg1")+","+string_special(" ...")+\
    "," + string_special(" argN")+")"
CMD_SYSCALL_HELP += "\n\n\t"+string_bold("Examples")+": "+\
    "\n\t\tsyscall -f python -c mprotect(0x123456, 200, 7)"+\
    "\n\t\tsyscall -c execve('/bin/sh', 0, 0)"+\
    "\n\t\tsyscall -l LINUX64"
def print_help():
    """Print the help text of the 'syscall' command (CMD_SYSCALL_HELP)."""
    print(CMD_SYSCALL_HELP)


def _parse_offset(text):
    """
    Parse a non-negative offset written in decimal or hexadecimal.

    The decimal interpretation is tried first, then the hexadecimal one
    (with or without '0x' prefix). Returns the offset as an integer, or
    None if `text` is not a valid non-negative number in either base.
    """
    for base in (10, 16):
        try:
            value = int(text, base)
        except ValueError:
            continue
        if value >= 0:
            return value
    return None


def syscall(args):
    """
    Entry point of the 'syscall' command.

    Parses the command tokens, then either lists the supported syscalls,
    prints the help, or builds and prints a ROPChain performing the
    requested syscall. All problems are reported through error() and
    make the function return early; nothing is returned to the caller.

    args -- list of strings: the tokens following the 'syscall' keyword.
            An empty list prints the help.

    Side effects: sets the module-level OUTPUT to the selected output
    format, and applies the requested offset with set_offset() for the
    duration of the chain building (it is always reset afterwards).
    """
    # Only OUTPUT is assigned here; the other constants are just read
    global OUTPUT
    if not args:
        print_help()
        return

    OUTPUT = OUTPUT_CONSOLE
    funcName = None
    funcArgs = None
    offset = 0
    seenOutput = False
    seenFunction = False
    seenBadBytes = False
    seenKeepRegs = False
    seenOffset = False

    constraint = Constraint()
    assertion = getBaseAssertion()

    i = 0
    while i < len(args):
        option = args[i]
        if option in [OPTION_LIST, OPTION_LIST_SHORT]:
            # Everything after the list option is a system name; options
            # that came before it must not be mistaken for systems
            listSyscalls(args[i+1:])
            return
        if option in [OPTION_BAD_BYTES, OPTION_BAD_BYTES_SHORT]:
            if seenBadBytes:
                error("Error. '" + option + "' option should be used only once")
                return
            if i+1 >= len(args):
                error("Error. Missing bad bytes after option '"+option+"'")
                return
            seenBadBytes = True
            (success, res) = parse_bad_bytes(args[i+1])
            if not success:
                # On failure, res holds the error message
                error(res)
                return
            i += 2
            constraint = constraint.add(BadBytes(res))
        elif option in [OPTION_KEEP_REGS, OPTION_KEEP_REGS_SHORT]:
            if seenKeepRegs:
                error("Error. '" + option + "' option should be used only once")
                return
            if i+1 >= len(args):
                error("Error. Missing register after option '"+option+"'")
                return
            seenKeepRegs = True
            (success, res) = parse_keep_regs(args[i+1])
            if not success:
                # On failure, res holds the error message
                error(res)
                return
            i += 2
            constraint = constraint.add(RegsNotModified(res))
        elif option in [OPTION_OFFSET, OPTION_OFFSET_SHORT]:
            if seenOffset:
                error("Error. '"+ option + "' option should be used only once")
                return
            if i+1 >= len(args):
                error("Error. Missing offset after option '"+option+"'")
                return
            offset = _parse_offset(args[i+1])
            if offset is None:
                error("Error. '" + args[i+1] +"' is not a valid offset")
                return
            i += 2
            seenOffset = True
        elif option in [OPTION_FUNCTION, OPTION_FUNCTION_SHORT]:
            if not loadedBinary():
                error("Error. You should load a binary before building ROPChains")
                return
            elif seenFunction:
                error("Option '{}' should be used only once".format(option))
                return
            # The function call may have been split into several tokens:
            # glue them back together up to the next option. startswith()
            # is used so that an empty token cannot raise an IndexError.
            userInput = ''
            i += 1
            while i < len(args) and not args[i].startswith("-"):
                userInput += args[i]
                i += 1
            (funcName, funcArgs) = parseFunction(userInput)
            if not funcName:
                return
            seenFunction = True
        elif option in [OPTION_OUTPUT, OPTION_OUTPUT_SHORT]:
            if seenOutput:
                error("Option '{}' should be used only once".format(option))
                return
            if i+1 >= len(args):
                error("Error. Missing output format after option '"+option+"'")
                return
            if args[i+1] not in [OUTPUT_CONSOLE, OUTPUT_PYTHON]:
                error("Error. Unknown output format: {}".format(args[i+1]))
                return
            OUTPUT = args[i+1]
            seenOutput = True
            i += 2
        elif option in [OPTION_HELP, OPTION_HELP_SHORT]:
            print_help()
            return
        else:
            error("Error. Unknown option '{}'".format(option))
            return

    if not funcName:
        error("Missing function to call")
        return

    # Set offset
    if not set_offset(offset):
        error("Error. Your offset is too big :'(")
        return
    try:
        call(funcName, funcArgs, constraint, assertion)
    finally:
        # Always reset the offset, whether call() returned normally,
        # raised, or was interrupted; any exception keeps propagating
        # with its original traceback
        reset_offset()
            

def call(funcName, parsedArgs, constraint, assertion):
    """
    Build a ROPChain performing the syscall `funcName` and print it.

    The syscall table is selected from the type of the currently loaded
    binary (Linux32 for X86 ELF, Linux64 for X64 ELF). An error is
    reported, and nothing is built, if the binary type is not supported,
    if the syscall is unknown for that system, or if the number of
    arguments does not match the syscall definition.

    funcName   -- name of the syscall to perform
    parsedArgs -- arguments of the call, as returned by parseFunction()
    constraint -- constraint the ROPChain must satisfy; its bad bytes
                  are also used when printing the chain
    assertion  -- assertion forwarded to the chain builder

    The chain is printed in the format selected by the module-level
    OUTPUT (console or python).
    """
    # Select the syscall table matching the loaded binary
    binType = Arch.currentBinType
    if binType == Arch.BinaryType.X86_ELF:
        table, system = Linux32.available, sysLinux32
    elif binType == Arch.BinaryType.X64_ELF:
        table, system = Linux64.available, sysLinux64
    else:
        error("Binary type '{}' not supported yet".format(binType))
        return

    # Look up the syscall definition and validate the call
    syscallDef = table.get(funcName)
    if not syscallDef:
        error("Syscall '{}' not supported for system '{}'".format(funcName, system))
        return
    if syscallDef.nb_args() != len(parsedArgs):
        error("Error. Wrong number of arguments")
        return

    # Build the chain
    chain = build_syscall_Linux(syscallDef, parsedArgs, Arch.bits(), constraint, assertion, optimizeLen=True)
    if not chain:
        print(string_bold("\n\tNo matching ROPChain found"))
        return

    # Display it in the requested format
    print(string_bold("\n\tFound matching ROPChain\n"))
    badBytes = constraint.getBadBytes()
    if OUTPUT == OUTPUT_CONSOLE:
        print(chain.strConsole(Arch.bits(), badBytes))
    elif OUTPUT == OUTPUT_PYTHON:
        print(chain.strPython(Arch.bits(), badBytes))

def listSyscalls(args):
    """
    Print the syscalls supported for the requested systems.

    args -- list of system names (members of availableSystems). When it
            is empty, the syscalls of every available system are listed.

    If any name is not a known system, an error is reported and nothing
    is printed. Each system is printed once, in the order it was first
    requested (or in the order of availableSystems when listing all),
    so the output order is stable from one run to the next.
    """
    if not args:
        systems = availableSystems
    else:
        systems = []
        for arg in args:
            if arg not in availableSystems:
                error("Unknown system: '{}'".format(arg))
                return
            # De-duplicate while keeping the order given by the user
            if arg not in systems:
                systems.append(arg)

    for system in systems:
        if system == sysLinux32:
            Linux32.print_available()
        elif system == sysLinux64:
            Linux64.print_available()

    
